# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
# Copyright (c) 2015 Mozilla Corporation

import hashlib

from mozdef_util.utilities.logger import logger


class message(object):
    """MozDef MQ plugin that turns 'vulnerability' events into vulnerability state.

    Events whose ``type`` is ``'vulnerability'`` are checked against the schema
    for their message version, assigned a deterministic document id derived
    from identifying fields, re-typed as ``'vulnerability_state'`` and directed
    to the ``vulnerabilities`` index. Events of any other type are returned
    unchanged.
    """

    MSG_VERSION_1 = 1
    MSG_VERSION_2 = 2

    class version_handler(object):
        """Pair a message version with its validate and calculate_id callables."""

        def __init__(self, ver, validate, calcid):
            self.version = ver
            self.validate = validate
            self.calculate_id = calcid

    def __init__(self):
        # Event types this plugin is registered for.
        self.registration = ['vulnerability']
        self.priority = 20
        self.handler_v1 = self.version_handler(
            self.MSG_VERSION_1,
            self.validate_v1,
            self.calculate_id_v1)
        self.handler_v2 = self.version_handler(
            self.MSG_VERSION_2,
            self.validate_v2,
            self.calculate_id_v2)

    @staticmethod
    def _has_fields(obj, fields):
        """Return True if *obj* is a dict that contains every key in *fields*."""
        # The isinstance guard stops a non-dict value (None, a number, a
        # string) from raising TypeError, or from "passing" through substring
        # matching when ``in`` is applied to a str.
        return isinstance(obj, dict) and all(field in obj for field in fields)

    def get_handler(self, message):
        """Return the version_handler for *message*, or None if unsupported.

        A message without a ``version`` field uses the original (v1) format.
        An explicit ``version`` of 1 or 2 (int or numeric string) selects the
        matching handler. Any other value, including one that cannot be
        converted with ``int()``, yields None.
        """
        if 'version' not in message:
            return self.handler_v1
        try:
            version = int(message['version'])
        except (TypeError, ValueError):
            # A malformed version rejects the event rather than raising.
            return None
        handlers = {
            self.MSG_VERSION_1: self.handler_v1,
            self.MSG_VERSION_2: self.handler_v2,
        }
        return handlers.get(version)

    def validate_v1(self, message):
        """Return True if *message* has every field required by the v1 format.

        The top level needs utctimestamp, description, vuln, asset and
        sourcename. ``asset`` must be a dict with assetid, ipv4address,
        hostname and macaddress. ``vuln`` must be a dict with status, vulnid,
        title, discovery_time, age_days, known_malware, known_exploits, cvss
        and cves. Only presence is checked, not values.
        """
        top_level = ('utctimestamp', 'description', 'vuln', 'asset',
                     'sourcename')
        if not all(k in message for k in top_level):
            return False
        asset_fields = ('assetid', 'ipv4address', 'hostname', 'macaddress')
        vuln_fields = ('status', 'vulnid', 'title', 'discovery_time',
                       'age_days', 'known_malware', 'known_exploits', 'cvss',
                       'cves')
        return (self._has_fields(message['asset'], asset_fields) and
                self._has_fields(message['vuln'], vuln_fields))

    def validate_v2(self, message):
        """Return True if *message* is a well-formed v2 vulnerability event.

        The top level needs utctimestamp, description, asset, sourcename and
        zone. ``asset`` must be a dict with hostname and ipaddress. zone,
        sourcename, asset.hostname and asset.ipaddress must not be empty
        strings.
        """
        top_level = ('utctimestamp', 'description', 'asset', 'sourcename',
                     'zone')
        if not all(k in message for k in top_level):
            return False
        asset = message['asset']
        if not self._has_fields(asset, ('hostname', 'ipaddress')):
            return False
        # These four fields make up calculate_id_v2's input, so empty values
        # are rejected.
        identifying = (message['zone'], message['sourcename'],
                       asset['ipaddress'], asset['hostname'])
        return '' not in identifying

    def calculate_id_v1(self, message):
        """Return the hex MD5 of ``"<asset.assetid>|<vuln.vulnid>|<sourcename>"``.

        The digest serves only as a stable document id, not for security.
        """
        s = '{0}|{1}|{2}'.format(
            message['asset']['assetid'],
            message['vuln']['vulnid'], message['sourcename'])
        return hashlib.md5(s.encode()).hexdigest()

    def calculate_id_v2(self, message):
        """Return the hex MD5 of ``"<zone>|<sourcename>|<hostname>|<ipaddress>"``.

        hostname and ipaddress come from ``message['asset']``. The digest
        serves only as a stable document id, not for security.
        """
        s = '{0}|{1}|{2}|{3}'.format(
            message['zone'],
            message['sourcename'], message['asset']['hostname'],
            message['asset']['ipaddress'])
        return hashlib.md5(s.encode()).hexdigest()

    def onMessage(self, message, metadata):
        """Validate and normalize a vulnerability event.

        :param message: decoded event dict. Accepted vulnerability events have
            ``type`` rewritten to ``'vulnerability_state'`` in place.
        :param metadata: event metadata dict. Accepted vulnerability events get
            ``id`` (from the version's calculate_id) and
            ``index = 'vulnerabilities'`` set on it.
        :returns: the unchanged ``(message, metadata)`` for non-vulnerability
            events, the updated pair for accepted vulnerability events, or
            ``(None, None)`` when the version is unsupported or validation
            fails (both cases are logged).
        """
        if 'type' not in message or message['type'] != 'vulnerability':
            return (message, metadata)
        handler = self.get_handler(message)
        if handler is None:
            logger.error('Unsupported version for vulnerability {0}'.format(message))
            return (None, None)
        if not handler.validate(message):
            logger.error('Invalid format for vulnerability {0}'.format(message))
            return (None, None)
        metadata['id'] = handler.calculate_id(message)
        message['type'] = 'vulnerability_state'
        metadata['index'] = 'vulnerabilities'
        return (message, metadata)
